import json
from collections import Counter
import numpy as np
import re
import sklearn
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.metrics import accuracy_score, f1_score

def problem2prompt(problem_dict, intro:str):
    """Build the full user prompt for one benchmark problem.

    The result is ``intro + " Specifically, my problem is: " + description``.

    When ``problem_dict`` has a ``'tests'`` key, the description is:
    - the problem's ``'query'``;
    - the top-level function signature(s) found in its ``'gold'`` solution;
    - the first test, if the test list is non-empty.

    Without ``'tests'`` the problem is treated as a bash task: the description
    is the query followed by a request for a bash command.

    Args:
        problem_dict: Problem record; reads ``'query'``, ``'gold'`` and
            (optionally) ``'tests'``.
        intro: Introduction text placed at the start of the prompt.

    Returns:
        The assembled prompt string.
    """
    def add_test_to_prompt(query, function_names, test=None):
        problem_prompt = query + f" Your response should have the following function signature(s): {','.join(function_names)}. "
        if test is not None:
            problem_prompt += f"Additionally, your response should pass the following test: {test}. "
        return problem_prompt

    # A ``def`` header up to (not including) the colon that opens the body.
    # Parameters may contain ':' (annotations, dict defaults) and there may be
    # a ``-> T`` return annotation, so a plain split(":") would cut them off.
    signature_re = re.compile(r"def\s+\w+\s*\(.*?\)(?:\s*->\s*[^:]+?)?(?=\s*:)")

    def extract_func_names_from_snippet(s):
        func_signatures = []
        for line in s.split("\n"):
            if line.startswith("def "):
                match = signature_re.match(line)
                # Headers the pattern does not recognise (e.g. a signature
                # spread over several lines) fall back to the simple split.
                func_signatures.append(match.group(0) if match else line.split(":")[0])
        return func_signatures

    if 'tests' in problem_dict:
        function_names = extract_func_names_from_snippet(problem_dict['gold'])
        tests = problem_dict['tests']
        test = tests[0] if tests else None
        problem_prompt = add_test_to_prompt(problem_dict["query"], function_names, test)
    else:
        # Include the problem statement; without it the prompt says only
        # "respond with a bash command" and never states the task.
        problem_prompt = problem_dict.get("query", "") + " Your response should be a bash command."
    return intro + " Specifically, my problem is: " + problem_prompt

# Introductions, one JSON object per line (each later parsed for its 'text'
# field). The slicing below treats the first `generate_intro_num` lines as
# non-expert intros and the next `generate_intro_num` lines as expert intros.
with open('./intro.jsonl', 'r', encoding='utf-8') as json_file:
    intro_list = list(json_file)
half_size = len(intro_list) // 2  # currently only referenced in commented-out code

with open("../data/python/mbpp/ic_mbpp.json", encoding="utf-8") as f:
    problem_dicts = json.load(f)

train_intro_num = 20      # intros per group used for training prompts
generate_intro_num = 40   # intros per group (non-expert / expert) in intro.jsonl
TEST_IDX_START = 10
TRAIN_IDX_START = 600
DataNum = 100             # number of test problems
TRAIN_DATA_NUM = 20       # number of training problems
test_intro_num = generate_intro_num - train_intro_num

# Fail early with a clear message instead of an IndexError while building
# the test prompts.
if len(intro_list) < 2 * generate_intro_num:
    raise ValueError(
        f"intro.jsonl has {len(intro_list)} lines; expected at least "
        f"{2 * generate_intro_num} ({generate_intro_num} non-expert followed by "
        f"{generate_intro_num} expert intros)"
    )

# Each group of `generate_intro_num` intros is split into a train part and a
# disjoint test part.
train_intro_none_expert = intro_list[:train_intro_num]
test_intro_none_expert = intro_list[train_intro_num:generate_intro_num]
train_intro_expert = intro_list[generate_intro_num:generate_intro_num + train_intro_num]
test_intro_expert = intro_list[generate_intro_num + train_intro_num:2 * generate_intro_num]

test_problem_dicts = problem_dicts[TEST_IDX_START:TEST_IDX_START + DataNum]
train_problem_dicts = problem_dicts[TRAIN_IDX_START:TRAIN_IDX_START + TRAIN_DATA_NUM]

def _normalize_prompt(text):
    """Lower-case ``text``, drop punctuation and collapse whitespace runs.

    Downstream code tokenises with ``str.split(' ')``. Without collapsing
    whitespace, punctuation removed next to a space (``"x , y"`` -> ``"x  y"``)
    would produce empty-string tokens. A newline would also glue two words
    into a single token.
    """
    return ' '.join(re.sub(r'[^\w\s]', '', text.lower()).split())


def _build_prompts(problems, expert_intros, none_expert_intros, intro_num):
    """Return ``(expert_prompts, none_expert_prompts)`` for ``problems``.

    Problem ``i`` is paired with intro ``i % intro_num`` of each intro list. The
    expert and non-expert prompts for a problem therefore differ only in the
    intro. Intro entries are JSON lines whose ``'text'`` field is used.
    """
    expert_prompts, none_expert_prompts = [], []
    for i, problem in enumerate(problems):
        intro_expert = json.loads(expert_intros[i % intro_num])['text']
        intro_none_expert = json.loads(none_expert_intros[i % intro_num])['text']
        expert_prompts.append(_normalize_prompt(problem2prompt(problem, intro_expert)))
        none_expert_prompts.append(_normalize_prompt(problem2prompt(problem, intro_none_expert)))
    return expert_prompts, none_expert_prompts


train_expert, train_none_expert = _build_prompts(
    train_problem_dicts, train_intro_expert, train_intro_none_expert, train_intro_num)
test_expert, test_none_expert = _build_prompts(
    test_problem_dicts, test_intro_expert, test_intro_none_expert, test_intro_num)

# Alternative setup (unused): classify the bare intros, split in half, without problem text.
# none_expert_list = [re.sub(r'[^\w\s]','',json.loads(di)['text'].lower().strip()) for di in intro_list[:half_size]]
# expert_list = [re.sub(r'[^\w\s]','',json.loads(di)['text'].lower().strip()) for di in intro_list[half_size:]]

def get_distribution(none_expert_list:list, expert_list:list, gram:int=1):
    """Compare word n-gram usage between non-expert and expert prompts.

    Each prompt is tokenised with ``str.split(' ')``. Every run of ``gram``
    consecutive tokens, joined by one space, is an n-gram. Prints a
    ``"<gram>-gram:"`` header.

    Args:
        none_expert_list: Prompts of the non-expert group.
        expert_list: Prompts of the expert group.
        gram: n-gram length.

    Returns:
        A tuple ``(expert_distribution, none_expert_distribution, dif_dict)``.

        The two distributions map every n-gram seen in either list to its
        total number of occurrences in that list, sorted by descending count.

        ``dif_dict`` maps every n-gram to ``round(df_none - df_expert, 2)``.
        Here ``df_x`` is the fraction of prompts in group x containing the
        n-gram at least once (0.0 for an empty group). The dict is sorted by
        descending absolute value. Positive values mark n-grams more typical
        of non-expert prompts.

        Ties keep lexicographic n-gram order, so results are reproducible.
    """
    def ngrams(text):
        words = text.split(' ')
        return [' '.join(words[idx:idx + gram]) for idx in range(len(words) - gram + 1)]

    none_expert_ngrams = [ngrams(intro) for intro in none_expert_list]
    expert_ngrams = [ngrams(intro) for intro in expert_list]

    # Sorted rather than raw set order: str hashing is randomised per process,
    # so set order (and therefore how the stable sorts below break ties,
    # including which n-gram ranks first) would change between runs.
    token_list = sorted({t for grams in none_expert_ngrams + expert_ngrams for t in grams})
    print(f"{gram}-gram:")

    def occurrence_counts(grams_per_intro):
        counts = Counter(t for grams in grams_per_intro for t in grams)
        dist = {token: counts[token] for token in token_list}
        return dict(sorted(dist.items(), key=lambda item: -item[1]))

    def document_fraction(grams_per_intro):
        # Each prompt counts at most once per n-gram.
        counts = Counter(t for grams in grams_per_intro for t in set(grams))
        n = len(grams_per_intro)
        return {token: (counts[token] / n if n else 0.0) for token in token_list}

    expert_distribution = occurrence_counts(expert_ngrams)
    none_expert_distribution = occurrence_counts(none_expert_ngrams)

    none_frac = document_fraction(none_expert_ngrams)
    expert_frac = document_fraction(expert_ngrams)
    # Normalise each group by its own size; dividing both counts by
    # len(expert_list) is only correct when the groups are equally large.
    dif_dict = {token: round(none_frac[token] - expert_frac[token], 2) for token in token_list}
    dif_dict = dict(sorted(dif_dict.items(), key=lambda item: -abs(item[1])))
    return expert_distribution, none_expert_distribution, dif_dict

def eval_dtm(test_expert, test_none_expert, dtm_dict, gram=None):
    """Evaluate a single-n-gram classifier on held-out prompts.

    ``dtm_dict`` is a ``(token, score)`` pair, e.g. one item of the
    ``dif_dict`` produced by ``get_distribution``.

    - Positive score: the n-gram marks non-expert prompts, so any prompt
      containing it is predicted non-expert and all others expert.
    - Zero or negative score: a prompt containing the n-gram is predicted
      expert.

    Prompts are tokenised with ``str.split(' ')``.

    Args:
        test_expert: Held-out expert prompts.
        test_none_expert: Held-out non-expert prompts.
        dtm_dict: ``(token, score)`` pair.
        gram: n-gram length of ``token``. Defaults to the number of
            ``' '``-separated parts of ``token``. That is exact for n-grams
            built from ``split(' ')`` tokens, which contain no spaces.

    Returns:
        ``(none_expert_acc, expert_acc)``: the fraction of each group that
        is classified correctly (``nan`` for an empty group).
    """
    token, score = dtm_dict[0], dtm_dict[1]
    if gram is None:
        # Derived from the token instead of relying on a global `gram`.
        gram = len(token.split(' '))

    def contains(intro):
        words = intro.split(' ')
        return any(token == ' '.join(words[idx:idx + gram]) for idx in range(len(words) - gram + 1))

    def fraction(k, n):
        return k / n if n else float('nan')

    c1 = sum(contains(intro) for intro in test_none_expert)
    c2 = sum(contains(intro) for intro in test_expert)
    n_none, n_expert = len(test_none_expert), len(test_expert)
    # Per-class accuracy: divide by each group's own size. Dividing by the
    # combined size of both groups misstates both accuracies.
    if score > 0:
        return fraction(c1, n_none), fraction(n_expert - c2, n_expert)
    return fraction(n_none - c1, n_none), fraction(c2, n_expert)

def regression(train_expert, test_expert, train_none_expert, test_none_expert, gram = 1):
    """Train and evaluate a logistic-regression classifier on n-gram presence.

    Every prompt is tokenised with ``str.split(' ')``. Each distinct n-gram
    of length ``gram`` becomes one binary (present/absent) feature.
    Non-expert prompts are labelled 1 and expert prompts 0.

    Prints:
    - the ten n-grams with the largest coefficients, then those coefficients;
    - the ten with the smallest coefficients, then those coefficients;
    - the single largest-coefficient n-gram.

    Returns:
        ``(accuracy, f1)`` on the test prompts, with non-expert (label 1) as
        the positive class for F1.
    """
    def ngrams(text):
        words = text.split(' ')
        return [' '.join(words[idx:idx + gram]) for idx in range(len(words) - gram + 1)]

    # The vocabulary spans train and test so every test n-gram has a column.
    # Columns that occur only in test are all-zero in the training matrix.
    # Sorted so the feature order (and printed tie order) does not depend on
    # per-process str hash randomisation.
    all_intros = train_expert + test_expert + train_none_expert + test_none_expert
    token_list = sorted({t for intro in all_intros for t in ngrams(intro)})
    token2id = {token: j for j, token in enumerate(token_list)}

    def featurize(intros):
        feature = np.zeros((len(intros), len(token2id)))
        for row, intro in enumerate(intros):
            for t in ngrams(intro):
                feature[row, token2id[t]] = 1
        return feature

    train_feature = featurize(train_none_expert + train_expert)
    test_feature = featurize(test_none_expert + test_expert)
    fit_target = np.array([1] * len(train_none_expert) + [0] * len(train_expert))
    gold = np.array([1] * len(test_none_expert) + [0] * len(test_expert))

    model = LogisticRegression().fit(train_feature, fit_target)
    coef = model.coef_[0]
    order = coef.argsort()
    token_arr = np.array(token_list)
    top, bottom = order[-10:][::-1], order[:10]
    print(token_arr[top])
    print(coef[top])
    print(token_arr[bottom])
    print(coef[bottom])
    print(token_arr[coef.argmax()])

    preds = model.predict(test_feature)
    # sklearn metrics take (y_true, y_pred).
    acc = accuracy_score(gold, preds)
    f1 = f1_score(gold, preds)
    return acc, f1

ll = [1, 2, 3, 4]  # n-gram lengths to evaluate
# The loop variable stays the module-level name `gram` because eval_dtm has
# historically read it as a global.
for gram in ll:
    # Pass by keyword: the parameter order is (none_expert_list, expert_list).
    # The previous positional call swapped the groups. That flipped the sign of
    # every dif_dict score relative to what eval_dtm expects, and swapped the
    # two returned distributions.
    expert_distribution, none_expert_distribution, dif_dict = get_distribution(
        none_expert_list=train_none_expert, expert_list=train_expert, gram=gram)
    ranked = list(dif_dict.items())
    dtm = ranked[0]  # most discriminative (n-gram, score) pair
    print(dtm)
    print(ranked[:5])
    acc_none, acc = eval_dtm(test_expert, test_none_expert, dtm)
    print(f"Accuracy: {(acc_none+acc)/2}  none_expert_acc: {acc_none}, expert_acc: {acc}")
    acc, f1 = regression(train_expert, test_expert, train_none_expert, test_none_expert, gram=gram)
    print(f"regression acc: {acc}   F1-score: {f1}")

